{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"NLU_traing_multi_label_classifier_E2e.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"markdown","metadata":{"id":"zkufh760uvF3"},"source":["![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)\n","\n","[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/nlu/blob/master/examples/colab/Training/multi_label_text_classification/NLU_traing_multi_label_classifier_E2e.ipynb)\n","\n","\n","\n","# Training a Deep Learning Classifier for multi-label prediction\n","MultiClassifierDL is a multi-label text classifier. It uses a bidirectional GRU with convolution model built inside TensorFlow and supports up to 100 classes. Its input is sentence embeddings, such as the state-of-the-art UniversalSentenceEncoder, BertSentenceEmbeddings, or SentenceEmbeddings.\n","\n","\n","\n","### MultiClassifierDL (multi-label text classification with multiple classes per sentence)\n","With the [MultiClassifierDL model](https://nlp.johnsnowlabs.com/docs/en/annotators#multiclassifierdl-multi-label-text-classification) from Spark NLP you can achieve state-of-the-art results on multi-label text classification problems.\n","\n","This notebook showcases the following features:\n","\n","- How to train the deep learning classifier\n","- How to store a pipeline to disk\n","- How to load the pipeline from disk (Enables NLU offline mode)\n","\n"]},{"cell_type":"markdown","metadata":{"id":"dur2drhW5Rvi"},"source":["# 1. Install Java 8 and NLU"]},{"cell_type":"code","metadata":{"id":"hFGnBCHavltY"},"source":["!wget https://setup.johnsnowlabs.com/nlu/colab.sh -O - | bash\n","  \n","\n","import nlu"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"f4KkTfnR5Ugg"},"source":["# 2. 
Download E2E Challenge multi token label classification dataset\n","\n","http://www.macs.hw.ac.uk/InteractionLab/E2E/"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":586},"id":"y4xSRWIhwT28","executionInfo":{"status":"ok","timestamp":1609529840956,"user_tz":-60,"elapsed":160088,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"39519c61-f3a4-4369-f72a-1f0590d9bb2e"},"source":["import pandas as pd\n","!wget http://ckl-it.de/wp-content/uploads/2020/12/e2e.csv\n","test_path = '/content/e2e.csv'\n","train_df = pd.read_csv(test_path)\n","train_df = train_df.iloc[:3000]\n","train_df"],"execution_count":null,"outputs":[{"output_type":"stream","text":["--2021-01-01 19:37:17--  http://ckl-it.de/wp-content/uploads/2020/12/e2e.csv\n","Resolving ckl-it.de (ckl-it.de)... 217.160.0.108, 2001:8d8:100f:f000::209\n","Connecting to ckl-it.de (ckl-it.de)|217.160.0.108|:80... connected.\n","HTTP request sent, awaiting response... 
200 OK\n","Length: 1322591 (1.3M) [text/csv]\n","Saving to: ‘e2e.csv’\n","\n","e2e.csv             100%[===================>]   1.26M   715KB/s    in 1.8s    \n","\n","2021-01-01 19:37:20 (715 KB/s) - ‘e2e.csv’ saved [1322591/1322591]\n","\n"],"name":"stdout"},{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>Unnamed: 0</th>\n","      <th>y</th>\n","      <th>text</th>\n","      <th>origin_index</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>0</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[cit...</td>\n","      <td>A coffee shop in the city centre area called B...</td>\n","      <td>0</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>1</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[cit...</td>\n","      <td>Blue Spice is a coffee shop in city centre.</td>\n","      <td>1</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>2</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[riv...</td>\n","      <td>There is a coffee shop Blue Spice in the river...</td>\n","      <td>2</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>3</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[riv...</td>\n","      <td>At the riverside, there is a coffee shop calle...</td>\n","      <td>3</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>4</td>\n","      <td>name[Blue Spice],eatType[coffee shop],customer...</td>\n","      <td>The coffee shop Blue Spice is based near Crown...</td>\n","      <td>4</td>\n","   
 </tr>\n","    <tr>\n","      <th>...</th>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","    </tr>\n","    <tr>\n","      <th>2995</th>\n","      <td>2995</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>Near Express by Holiday Inn, in the riverside ...</td>\n","      <td>2995</td>\n","    </tr>\n","    <tr>\n","      <th>2996</th>\n","      <td>2996</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>In the riverside area, near Express by Holiday...</td>\n","      <td>2996</td>\n","    </tr>\n","    <tr>\n","      <th>2997</th>\n","      <td>2997</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>The Punter is a restaurant with Indian food in...</td>\n","      <td>2997</td>\n","    </tr>\n","    <tr>\n","      <th>2998</th>\n","      <td>2998</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>The Punter is a low rated restaurant that serv...</td>\n","      <td>2998</td>\n","    </tr>\n","    <tr>\n","      <th>2999</th>\n","      <td>2999</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","      <td>2999</td>\n","    </tr>\n","  </tbody>\n","</table>\n","<p>3000 rows × 4 columns</p>\n","</div>"],"text/plain":["      Unnamed: 0  ... origin_index\n","0              0  ...            0\n","1              1  ...            1\n","2              2  ...            2\n","3              3  ...            3\n","4              4  ...            4\n","...          ...  ...          ...\n","2995        2995  ...         2995\n","2996        2996  ...         2996\n","2997        2997  ...         2997\n","2998        2998  ...         2998\n","2999        2999  ...         
2999\n","\n","[3000 rows x 4 columns]"]},"metadata":{"tags":[]},"execution_count":2}]},{"cell_type":"markdown","metadata":{"id":"0296Om2C5anY"},"source":["# 3. Train Deep Learning Classifier using nlu.load('train.multi_classifier')\n","\n","By default, the Universal Sentence Encoder (USE) embeddings are downloaded to provide embeddings for the classifier. You can use any of the 50+ other sentence embeddings in NLU, though!\n","\n","Your dataset's label column should be named 'y' and the feature column with text data should be named 'text'."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":471},"id":"3ZIPkRkWftBG","executionInfo":{"status":"ok","timestamp":1609522208492,"user_tz":-60,"elapsed":410284,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"bda58bd4-d56e-471c-deea-37fe6e06af5e"},"source":["import nlu\n","# load a trainable pipeline by specifying the train prefix\n","unfitted_pipe = nlu.load('train.multi_classifier')\n","# configure epochs\n","unfitted_pipe['trainable_multi_classifier_dl'].setMaxEpochs(25)\n","# fit it on a dataset with label='y' and text columns. 
Labels seperated by ','\n","fitted_pipe = unfitted_pipe.fit(train_df[['y','text']], label_seperator=',')\n","\n","# predict with the trained pipeline on dataset and get predictions\n","preds = fitted_pipe.predict(train_df[['y','text']])\n","preds"],"execution_count":null,"outputs":[{"output_type":"stream","text":["tfhub_use download started this may take some time.\n","Approximate size to download 923.7 MB\n","[OK!]\n"],"name":"stdout"},{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>multi_classifier_classes</th>\n","      <th>multi_classifier_confidences</th>\n","      <th>default_name_embeddings</th>\n","      <th>y</th>\n","      <th>sentence</th>\n","      <th>text</th>\n","    </tr>\n","    <tr>\n","      <th>origin_index</th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>[near[Café Rouge], name[Blue Spice], near[Rain...</td>\n","      <td>[0.8555223, 0.99276984, 0.87128675, 0.9852337,...</td>\n","      <td>[0.026563657447695732, -0.058662936091423035, ...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[cit...</td>\n","      <td>A coffee shop in the city centre area called B...</td>\n","      <td>A coffee shop in the city centre area called B...</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>[near[Café Rouge], name[Blue Spice], near[Rain...</td>\n","      <td>[0.8142674, 0.99920505, 0.93413615, 0.98056525...</td>\n","      
<td>[0.040952689945697784, -0.04276810586452484, -...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[cit...</td>\n","      <td>Blue Spice is a coffee shop in city centre.</td>\n","      <td>Blue Spice is a coffee shop in city centre.</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>[name[Blue Spice], near[Rainbow Vegetarian Caf...</td>\n","      <td>[0.9966337, 0.9044244, 0.904881, 0.56231284, 0...</td>\n","      <td>[0.03141527622938156, -0.05154882371425629, 0....</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[riv...</td>\n","      <td>There is a coffee shop Blue Spice in the river...</td>\n","      <td>There is a coffee shop Blue Spice in the river...</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>[near[Café Rouge], name[Blue Spice], near[Rain...</td>\n","      <td>[0.5227911, 0.99917483, 0.9394022, 0.8839797, ...</td>\n","      <td>[0.03584946319460869, -0.036898739635944366, -...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[riv...</td>\n","      <td>At the riverside, there is a coffee shop calle...</td>\n","      <td>At the riverside, there is a coffee shop calle...</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>[near[Café Rouge], name[Blue Spice], near[Crow...</td>\n","      <td>[0.5985904, 0.7892299, 0.8222753, 0.9378743, 0...</td>\n","      <td>[0.0405426099896431, -0.0243277158588171, 0.00...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],customer...</td>\n","      <td>The coffee shop Blue Spice is based near Crown...</td>\n","      <td>The coffee shop Blue Spice is based near Crown...</td>\n","    </tr>\n","    <tr>\n","      <th>...</th>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","    </tr>\n","    <tr>\n","      <th>2998</th>\n","      <td>[near[Express by Holiday Inn], priceRange[high...</td>\n","      <td>[0.9999982, 0.8146039, 0.99978125, 
0.8511795, ...</td>\n","      <td>[0.05956212058663368, 0.019028551876544952, -0...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>The Punter has a price range of less than £20,...</td>\n","      <td>The Punter is a low rated restaurant that serv...</td>\n","    </tr>\n","    <tr>\n","      <th>2999</th>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>[0.99992794, 0.99981034, 0.5099642, 0.9994041,...</td>\n","      <td>[0.04296032711863518, -0.0015949805965647101, ...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","    </tr>\n","    <tr>\n","      <th>2999</th>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>[0.99992794, 0.99981034, 0.5099642, 0.9994041,...</td>\n","      <td>[0.023289771750569344, 0.056861914694309235, -...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>It is located in the riverside.</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","    </tr>\n","    <tr>\n","      <th>2999</th>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>[0.99992794, 0.99981034, 0.5099642, 0.9994041,...</td>\n","      <td>[0.033101629465818405, 0.06402800232172012, 0....</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>It is near Express by Holiday Inn.</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","    </tr>\n","    <tr>\n","      <th>2999</th>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>[0.99992794, 0.99981034, 0.5099642, 0.9994041,...</td>\n","      <td>[0.01677701249718666, 0.04876527190208435, -0....</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      
<td>Its customer rating is low.</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","    </tr>\n","  </tbody>\n","</table>\n","<p>5266 rows × 6 columns</p>\n","</div>"],"text/plain":["                                       multi_classifier_classes  ...                                               text\n","origin_index                                                     ...                                                   \n","0             [near[Café Rouge], name[Blue Spice], near[Rain...  ...  A coffee shop in the city centre area called B...\n","1             [near[Café Rouge], name[Blue Spice], near[Rain...  ...        Blue Spice is a coffee shop in city centre.\n","2             [name[Blue Spice], near[Rainbow Vegetarian Caf...  ...  There is a coffee shop Blue Spice in the river...\n","3             [near[Café Rouge], name[Blue Spice], near[Rain...  ...  At the riverside, there is a coffee shop calle...\n","4             [near[Café Rouge], name[Blue Spice], near[Crow...  ...  The coffee shop Blue Spice is based near Crown...\n","...                                                         ...  ...                                                ...\n","2998          [near[Express by Holiday Inn], priceRange[high...  ...  The Punter is a low rated restaurant that serv...\n","2999          [near[Express by Holiday Inn], food[Indian], c...  ...  The Punter is a restaurant providing Indian fo...\n","2999          [near[Express by Holiday Inn], food[Indian], c...  ...  The Punter is a restaurant providing Indian fo...\n","2999          [near[Express by Holiday Inn], food[Indian], c...  ...  The Punter is a restaurant providing Indian fo...\n","2999          [near[Express by Holiday Inn], food[Indian], c...  ...  The Punter is a restaurant providing Indian fo...\n","\n","[5266 rows x 6 columns]"]},"metadata":{"tags":[]},"execution_count":3}]},{"cell_type":"markdown","metadata":{"id":"DL_5aY9b3jSd"},"source":["# 4. 
Evaluate the model"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0YDA2KunCeqQ","executionInfo":{"status":"ok","timestamp":1609522209572,"user_tz":-60,"elapsed":411343,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"37539c88-d18c-425d-a28d-4127dc9bbb99"},"source":["from sklearn.preprocessing import MultiLabelBinarizer\n","from sklearn.metrics import classification_report\n","from sklearn.metrics import f1_score\n","from sklearn.metrics import roc_auc_score\n","# the predicted labels are in the 'multi_classifier_classes' column (one list of labels per row)\n","mlb = MultiLabelBinarizer()\n","mlb = mlb.fit(preds.y.str.split(','))\n","y_true = mlb.transform(preds['y'].str.split(','))\n","y_pred = mlb.transform(preds['multi_classifier_classes'])\n","print(\"Classification report: \\n\", (classification_report(y_true, y_pred)))\n","print(\"F1 micro averaging:\",(f1_score(y_true, y_pred, average='micro')))\n","print(\"ROC: \",(roc_auc_score(y_true, y_pred, average=\"micro\")))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Classification report: \n","               precision    recall  f1-score   support\n","\n","           0       0.78      0.97      0.86      1700\n","           1       0.95      0.83      0.89      2914\n","           2       0.56      0.64      0.60       576\n","           3       0.33      0.28      0.30       367\n","           4       0.38      0.55      0.45       455\n","           5       0.30      0.76      0.42       599\n","           6       0.37      0.77      0.50       550\n","           7       0.69      0.44      0.54       457\n","           8       0.99      0.72      0.84       337\n","           9       0.91      0.98      0.95      2211\n","          10       0.89      0.99      0.94      2718\n","          11       0.53      0.89      0.67      
1914\n","          12       0.88      0.79      0.84      3154\n","          13       0.79      0.98      0.87      1087\n","          14       0.69      0.97      0.81      1118\n","          15       0.98      0.64      0.78      1077\n","          16       0.82      0.96      0.88       671\n","          17       0.71      1.00      0.83       323\n","          18       0.57      0.65      0.61       130\n","          19       0.96      0.80      0.87       186\n","          20       0.77      0.99      0.87       366\n","          21       0.57      0.20      0.30        40\n","          22       0.36      0.10      0.15        42\n","          23       0.00      0.00      0.00         4\n","          24       0.97      0.97      0.97       322\n","          25       0.99      0.83      0.91       338\n","          26       0.00      0.00      0.00         6\n","          27       0.00      0.00      0.00        34\n","          28       0.94      0.99      0.96      1273\n","          29       0.96      1.00      0.98       987\n","          30       0.90      0.99      0.95      1140\n","          31       0.74      0.85      0.79       186\n","          32       0.45      0.98      0.62       528\n","          33       0.91      0.97      0.93       662\n","          34       0.90      0.60      0.72       116\n","          35       0.67      0.09      0.16        22\n","          36       0.58      0.98      0.73       484\n","          37       0.88      0.77      0.82       601\n","          38       0.94      0.97      0.96       711\n","          39       0.99      0.96      0.97       620\n","          40       0.96      0.99      0.98       526\n","          41       0.98      1.00      0.99      1410\n","          42       1.00      0.28      0.43        72\n","          43       0.00      0.00      0.00         8\n","          44       0.00      0.00      0.00         8\n","          45       0.00      0.00      0.00         4\n","          46       
0.35      0.42      0.38       595\n","          47       0.34      0.66      0.45       849\n","          48       0.57      0.44      0.50       627\n","          49       0.69      0.53      0.60       767\n","          50       0.31      0.32      0.32       347\n","          51       0.25      0.53      0.34       453\n","\n","   micro avg       0.73      0.84      0.78     36692\n","   macro avg       0.64      0.65      0.62     36692\n","weighted avg       0.78      0.84      0.80     36692\n"," samples avg       0.76      0.84      0.79     36692\n","\n","F1 micro averaging: 0.7831856729396004\n","ROC:  0.8980818453315285\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"mhFKVN93o1ZO"},"source":["# 5. Let's try different Sentence Embeddings"]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"CzJd8omao0gt","executionInfo":{"status":"ok","timestamp":1609522209573,"user_tz":-60,"elapsed":411328,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"ce35ce12-fbc8-4e0f-c9a1-6feaf68da7b0"},"source":["# We can use nlu.print_components(action='embed_sentence') to see every possible sentence embedding we could use. 
Lets use bert!\n","nlu.print_components(action='embed_sentence')"],"execution_count":null,"outputs":[{"output_type":"stream","text":["For language <en> NLU provides the following Models : \n","nlu.load('en.embed_sentence') returns Spark NLP model tfhub_use\n","nlu.load('en.embed_sentence.use') returns Spark NLP model tfhub_use\n","nlu.load('en.embed_sentence.tfhub_use') returns Spark NLP model tfhub_use\n","nlu.load('en.embed_sentence.use.lg') returns Spark NLP model tfhub_use_lg\n","nlu.load('en.embed_sentence.tfhub_use.lg') returns Spark NLP model tfhub_use_lg\n","nlu.load('en.embed_sentence.albert') returns Spark NLP model albert_base_uncased\n","nlu.load('en.embed_sentence.electra') returns Spark NLP model sent_electra_small_uncased\n","nlu.load('en.embed_sentence.electra_small_uncased') returns Spark NLP model sent_electra_small_uncased\n","nlu.load('en.embed_sentence.electra_base_uncased') returns Spark NLP model sent_electra_base_uncased\n","nlu.load('en.embed_sentence.electra_large_uncased') returns Spark NLP model sent_electra_large_uncased\n","nlu.load('en.embed_sentence.bert') returns Spark NLP model sent_bert_base_uncased\n","nlu.load('en.embed_sentence.bert_base_uncased') returns Spark NLP model sent_bert_base_uncased\n","nlu.load('en.embed_sentence.bert_base_cased') returns Spark NLP model sent_bert_base_cased\n","nlu.load('en.embed_sentence.bert_large_uncased') returns Spark NLP model sent_bert_large_uncased\n","nlu.load('en.embed_sentence.bert_large_cased') returns Spark NLP model sent_bert_large_cased\n","nlu.load('en.embed_sentence.biobert.pubmed_base_cased') returns Spark NLP model sent_biobert_pubmed_base_cased\n","nlu.load('en.embed_sentence.biobert.pubmed_large_cased') returns Spark NLP model sent_biobert_pubmed_large_cased\n","nlu.load('en.embed_sentence.biobert.pmc_base_cased') returns Spark NLP model sent_biobert_pmc_base_cased\n","nlu.load('en.embed_sentence.biobert.pubmed_pmc_base_cased') returns Spark NLP model 
sent_biobert_pubmed_pmc_base_cased\n","nlu.load('en.embed_sentence.biobert.clinical_base_cased') returns Spark NLP model sent_biobert_clinical_base_cased\n","nlu.load('en.embed_sentence.biobert.discharge_base_cased') returns Spark NLP model sent_biobert_discharge_base_cased\n","nlu.load('en.embed_sentence.covidbert.large_uncased') returns Spark NLP model sent_covidbert_large_uncased\n","nlu.load('en.embed_sentence.small_bert_L2_128') returns Spark NLP model sent_small_bert_L2_128\n","nlu.load('en.embed_sentence.small_bert_L4_128') returns Spark NLP model sent_small_bert_L4_128\n","nlu.load('en.embed_sentence.small_bert_L6_128') returns Spark NLP model sent_small_bert_L6_128\n","nlu.load('en.embed_sentence.small_bert_L8_128') returns Spark NLP model sent_small_bert_L8_128\n","nlu.load('en.embed_sentence.small_bert_L10_128') returns Spark NLP model sent_small_bert_L10_128\n","nlu.load('en.embed_sentence.small_bert_L12_128') returns Spark NLP model sent_small_bert_L12_128\n","nlu.load('en.embed_sentence.small_bert_L2_256') returns Spark NLP model sent_small_bert_L2_256\n","nlu.load('en.embed_sentence.small_bert_L4_256') returns Spark NLP model sent_small_bert_L4_256\n","nlu.load('en.embed_sentence.small_bert_L6_256') returns Spark NLP model sent_small_bert_L6_256\n","nlu.load('en.embed_sentence.small_bert_L8_256') returns Spark NLP model sent_small_bert_L8_256\n","nlu.load('en.embed_sentence.small_bert_L10_256') returns Spark NLP model sent_small_bert_L10_256\n","nlu.load('en.embed_sentence.small_bert_L12_256') returns Spark NLP model sent_small_bert_L12_256\n","nlu.load('en.embed_sentence.small_bert_L2_512') returns Spark NLP model sent_small_bert_L2_512\n","nlu.load('en.embed_sentence.small_bert_L4_512') returns Spark NLP model sent_small_bert_L4_512\n","nlu.load('en.embed_sentence.small_bert_L6_512') returns Spark NLP model sent_small_bert_L6_512\n","nlu.load('en.embed_sentence.small_bert_L8_512') returns Spark NLP model 
sent_small_bert_L8_512\n","nlu.load('en.embed_sentence.small_bert_L10_512') returns Spark NLP model sent_small_bert_L10_512\n","nlu.load('en.embed_sentence.small_bert_L12_512') returns Spark NLP model sent_small_bert_L12_512\n","nlu.load('en.embed_sentence.small_bert_L2_768') returns Spark NLP model sent_small_bert_L2_768\n","nlu.load('en.embed_sentence.small_bert_L4_768') returns Spark NLP model sent_small_bert_L4_768\n","nlu.load('en.embed_sentence.small_bert_L6_768') returns Spark NLP model sent_small_bert_L6_768\n","nlu.load('en.embed_sentence.small_bert_L8_768') returns Spark NLP model sent_small_bert_L8_768\n","nlu.load('en.embed_sentence.small_bert_L10_768') returns Spark NLP model sent_small_bert_L10_768\n","nlu.load('en.embed_sentence.small_bert_L12_768') returns Spark NLP model sent_small_bert_L12_768\n","For language <fi> NLU provides the following Models : \n","nlu.load('fi.embed_sentence') returns Spark NLP model sent_bert_finnish_cased\n","nlu.load('fi.embed_sentence.bert.cased') returns Spark NLP model sent_bert_finnish_cased\n","nlu.load('fi.embed_sentence.bert.uncased') returns Spark NLP model sent_bert_finnish_uncased\n","For language <xx> NLU provides the following Models : \n","nlu.load('xx.embed_sentence') returns Spark NLP model sent_bert_multi_cased\n","nlu.load('xx.embed_sentence.bert') returns Spark NLP model sent_bert_multi_cased\n","nlu.load('xx.embed_sentence.bert.cased') returns Spark NLP model sent_bert_multi_cased\n","nlu.load('xx.embed_sentence.labse') returns Spark NLP model labse\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0ofYHpu7sloS","executionInfo":{"status":"ok","timestamp":1609529895586,"user_tz":-60,"elapsed":54621,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"44154b28-c1db-4f58-bab1-7ac185fa40b8"},"source":["# You 
might need to restart your notebook to clear RAM, or you might run out of Memory when fitting\n","import nlu\n","pipe = nlu.load('en.embed_sentence.small_bert_L12_768 train.multi_classifier')\n","pipe.print_info()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["sent_small_bert_L12_768 download started this may take some time.\n","Approximate size to download 392.9 MB\n","[OK!]\n","The following parameters are configurable for this NLU pipeline (You can copy paste the examples) :\n",">>> pipe['en_embed_sentence_small_bert_L12_768'] has settable params:\n","pipe['en_embed_sentence_small_bert_L12_768'].setBatchSize(32)  | Info: Batch size. Large values allows faster processing but requires more memory. | Currently set to : 32\n","pipe['en_embed_sentence_small_bert_L12_768'].setIsLong(False)  | Info: Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int. | Currently set to : False\n","pipe['en_embed_sentence_small_bert_L12_768'].setMaxSentenceLength(128)  | Info: Max sentence length to process | Currently set to : 128\n","pipe['en_embed_sentence_small_bert_L12_768'].setDimension(768)  | Info: Number of embedding dimensions | Currently set to : 768\n","pipe['en_embed_sentence_small_bert_L12_768'].setCaseSensitive(False)  | Info: whether to ignore case in tokens for embeddings matching | Currently set to : False\n","pipe['en_embed_sentence_small_bert_L12_768'].setStorageRef('sent_small_bert_L12_768')  | Info: unique reference name for identification | Currently set to : sent_small_bert_L12_768\n",">>> pipe['sentence_detector'] has settable params:\n","pipe['sentence_detector'].setUseAbbreviations(True)  | Info: whether to apply abbreviations at sentence detection | Currently set to : True\n","pipe['sentence_detector'].setDetectLists(True)      | Info: whether detect lists during sentence detection | Currently set to : True\n","pipe['sentence_detector'].setUseCustomBoundsOnly(False)  | Info: Only utilize 
custom bounds in sentence detection | Currently set to : False\n","pipe['sentence_detector'].setCustomBounds([])       | Info: characters used to explicitly mark sentence bounds | Currently set to : []\n","pipe['sentence_detector'].setExplodeSentences(False)  | Info: whether to explode each sentence into a different row, for better parallelization. Defaults to false. | Currently set to : False\n","pipe['sentence_detector'].setMinLength(0)           | Info: Set the minimum allowed length for each sentence. | Currently set to : 0\n","pipe['sentence_detector'].setMaxLength(99999)       | Info: Set the maximum allowed length for each sentence | Currently set to : 99999\n",">>> pipe['default_tokenizer'] has settable params:\n","pipe['default_tokenizer'].setTargetPattern('\\S+')   | Info: pattern to grab from text as token candidates. Defaults \\S+ | Currently set to : \\S+\n","pipe['default_tokenizer'].setContextChars(['.', ',', ';', ':', '!', '?', '*', '-', '(', ')', '\"', \"'\"])  | Info: character list used to separate from token boundaries | Currently set to : ['.', ',', ';', ':', '!', '?', '*', '-', '(', ')', '\"', \"'\"]\n","pipe['default_tokenizer'].setCaseSensitiveExceptions(True)  | Info: Whether to care for case sensitiveness in exceptions | Currently set to : True\n","pipe['default_tokenizer'].setMinLength(0)           | Info: Set the minimum allowed legth for each token | Currently set to : 0\n","pipe['default_tokenizer'].setMaxLength(99999)       | Info: Set the maximum allowed legth for each token | Currently set to : 99999\n",">>> pipe['document_assembler'] has settable params:\n","pipe['document_assembler'].setCleanupMode('shrink')  | Info: possible values: disabled, inplace, inplace_full, shrink, shrink_full, each, each_full, delete_full | Currently set to : shrink\n",">>> pipe['multi_classifier'] has settable params:\n","pipe['multi_classifier'].setMaxEpochs(2)            | Info: Maximum number of epochs to train | Currently set to : 
2\n","pipe['multi_classifier'].setLr(0.001)               | Info: Learning Rate | Currently set to : 0.001\n","pipe['multi_classifier'].setBatchSize(64)           | Info: Batch size | Currently set to : 64\n","pipe['multi_classifier'].setValidationSplit(0.0)    | Info: Choose the proportion of training dataset to be validated against the model on each Epoch. The value should be between 0.0 and 1.0 and by default it is 0.0 and off. | Currently set to : 0.0\n","pipe['multi_classifier'].setThreshold(0.5)          | Info: The minimum threshold for each label to be accepted. Default is 0.5 | Currently set to : 0.5\n","pipe['multi_classifier'].setRandomSeed(44)          | Info: Random seed | Currently set to : 44\n","pipe['multi_classifier'].setShufflePerEpoch(False)  | Info: whether to shuffle the training data on each Epoch | Currently set to : False\n","pipe['multi_classifier'].setEnableOutputLogs(True)  | Info: Whether to use stdout in addition to Spark logs. | Currently set to : True\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"ABHLgirmG1n9","colab":{"base_uri":"https://localhost:8080/","height":417},"executionInfo":{"status":"ok","timestamp":1609531977887,"user_tz":-60,"elapsed":2136903,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"d312277d-3826-46e2-c67e-4a10a7116c4f"},"source":["\n","# Load pipe with bert embeds and configure hyper parameters\n","# using large embeddings can take a few hours..\n","pipe['trainable_multi_classifier_dl'].setMaxEpochs(100)            \n","pipe['trainable_multi_classifier_dl'].setLr(0.0005)  \n","fitted_pipe = pipe.fit(train_df[['y','text']],label_seperator=',')\n","preds = fitted_pipe.predict(train_df)\n","preds"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n"," 
       vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>text</th>\n","      <th>multi_classifier_classes</th>\n","      <th>Unnamed: 0</th>\n","      <th>document</th>\n","      <th>y</th>\n","      <th>multi_classifier_confidences</th>\n","      <th>en_embed_sentence_small_bert_L12_768_embeddings</th>\n","    </tr>\n","    <tr>\n","      <th>origin_index</th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>A coffee shop in the city centre area called B...</td>\n","      <td>[name[Blue Spice], eatType[coffee shop], area[...</td>\n","      <td>0</td>\n","      <td>A coffee shop in the city centre area called B...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[cit...</td>\n","      <td>[0.9740321, 0.99538183, 0.92562413]</td>\n","      <td>[-0.1427491158246994, 0.5036071538925171, 0.07...</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>Blue Spice is a coffee shop in city centre.</td>\n","      <td>[name[Blue Spice], eatType[coffee shop], area[...</td>\n","      <td>1</td>\n","      <td>Blue Spice is a coffee shop in city centre.</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[cit...</td>\n","      <td>[0.9950888, 0.9989519, 0.8684354]</td>\n","      <td>[-0.20697341859340668, 0.5286431312561035, 0.2...</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>There is a coffee shop Blue Spice in the river...</td>\n","      <td>[name[Blue Spice], eatType[coffee shop], area[...</td>\n","      <td>2</td>\n","      <td>There is a coffee shop Blue 
Spice in the river...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[riv...</td>\n","      <td>[0.95310336, 0.9655487, 0.9785502]</td>\n","      <td>[0.005826675333082676, 0.49930453300476074, -0...</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>At the riverside, there is a coffee shop calle...</td>\n","      <td>[name[Blue Spice], eatType[coffee shop], area[...</td>\n","      <td>3</td>\n","      <td>At the riverside, there is a coffee shop calle...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],area[riv...</td>\n","      <td>[0.8858954, 0.931189, 0.9990605]</td>\n","      <td>[0.12191159278154373, 0.37966835498809814, 0.0...</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>The coffee shop Blue Spice is based near Crown...</td>\n","      <td>[near[Crowne Plaza Hotel], customer rating[5 o...</td>\n","      <td>4</td>\n","      <td>The coffee shop Blue Spice is based near Crown...</td>\n","      <td>name[Blue Spice],eatType[coffee shop],customer...</td>\n","      <td>[0.99912286, 0.7930833, 0.9730882]</td>\n","      <td>[-0.37350592017173767, 0.1885937601327896, 0.1...</td>\n","    </tr>\n","    <tr>\n","      <th>...</th>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","      <td>...</td>\n","    </tr>\n","    <tr>\n","      <th>2995</th>\n","      <td>Near Express by Holiday Inn, in the riverside ...</td>\n","      <td>[near[Express by Holiday Inn], customer rating...</td>\n","      <td>2995</td>\n","      <td>Near Express by Holiday Inn, in the riverside ...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>[0.9476669, 0.9914391, 0.8395983, 0.98047745, ...</td>\n","      <td>[0.0485222227871418, 0.2381688505411148, 0.227...</td>\n","    </tr>\n","    <tr>\n","      <th>2996</th>\n","      <td>In the riverside area, near Express by Holiday...</td>\n","      
<td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>2996</td>\n","      <td>In the riverside area, near Express by Holiday...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>[0.94435394, 0.6119035, 0.7891044, 0.9885667, ...</td>\n","      <td>[0.06879807263612747, 0.23580998182296753, 0.1...</td>\n","    </tr>\n","    <tr>\n","      <th>2997</th>\n","      <td>The Punter is a restaurant with Indian food in...</td>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>2997</td>\n","      <td>The Punter is a restaurant with Indian food in...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>[0.99509084, 0.9424925, 0.7625178, 0.9907007, ...</td>\n","      <td>[-0.12667560577392578, 0.22056235373020172, 0....</td>\n","    </tr>\n","    <tr>\n","      <th>2998</th>\n","      <td>The Punter is a low rated restaurant that serv...</td>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>2998</td>\n","      <td>The Punter is a low rated restaurant that serv...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>[0.99541605, 0.9715836, 0.87202764, 0.99880993...</td>\n","      <td>[-0.13057495653629303, 0.21937601268291473, 0....</td>\n","    </tr>\n","    <tr>\n","      <th>2999</th>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","      <td>[near[Express by Holiday Inn], food[Indian], c...</td>\n","      <td>2999</td>\n","      <td>The Punter is a restaurant providing Indian fo...</td>\n","      <td>name[The Punter],eatType[restaurant],food[Indi...</td>\n","      <td>[0.98941034, 0.99086845, 0.82358456, 0.985973,...</td>\n","      <td>[-0.10767646133899689, 0.2529870569705963, 0.2...</td>\n","    </tr>\n","  </tbody>\n","</table>\n","<p>3000 rows × 7 columns</p>\n","</div>"],"text/plain":["                                                           text  ...   
 en_embed_sentence_small_bert_L12_768_embeddings\n","origin_index                                                     ...                                                   \n","0             A coffee shop in the city centre area called B...  ...  [-0.1427491158246994, 0.5036071538925171, 0.07...\n","1                   Blue Spice is a coffee shop in city centre.  ...  [-0.20697341859340668, 0.5286431312561035, 0.2...\n","2             There is a coffee shop Blue Spice in the river...  ...  [0.005826675333082676, 0.49930453300476074, -0...\n","3             At the riverside, there is a coffee shop calle...  ...  [0.12191159278154373, 0.37966835498809814, 0.0...\n","4             The coffee shop Blue Spice is based near Crown...  ...  [-0.37350592017173767, 0.1885937601327896, 0.1...\n","...                                                         ...  ...                                                ...\n","2995          Near Express by Holiday Inn, in the riverside ...  ...  [0.0485222227871418, 0.2381688505411148, 0.227...\n","2996          In the riverside area, near Express by Holiday...  ...  [0.06879807263612747, 0.23580998182296753, 0.1...\n","2997          The Punter is a restaurant with Indian food in...  ...  [-0.12667560577392578, 0.22056235373020172, 0....\n","2998          The Punter is a low rated restaurant that serv...  ...  [-0.13057495653629303, 0.21937601268291473, 0....\n","2999          The Punter is a restaurant providing Indian fo...  ...  
[-0.10767646133899689, 0.2529870569705963, 0.2...\n","\n","[3000 rows x 7 columns]"]},"metadata":{"tags":[]},"execution_count":4}]},{"cell_type":"code","metadata":{"id":"E7ah2LM6tIhG","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609531978935,"user_tz":-60,"elapsed":2137934,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"2636e995-5ef1-4457-895e-adcdf34f40c1"},"source":["from sklearn.preprocessing import MultiLabelBinarizer\n","from sklearn.metrics import classification_report\n","from sklearn.metrics import f1_score\n","from sklearn.metrics import roc_auc_score\n","# Fit the binarizer on the gold labels, which are stored as comma-separated strings\n","mlb = MultiLabelBinarizer()\n","mlb = mlb.fit(preds.y.str.split(','))\n","y_true = mlb.transform(preds['y'].str.split(','))\n","# Predicted labels are lists in the multi_classifier_classes column shown above\n","y_pred = mlb.transform(preds['multi_classifier_classes'].str.join(',').str.split(','))\n","print(\"Classification report: \\n\", (classification_report(y_true, y_pred)))\n","print(\"F1 micro averaging:\",(f1_score(y_true, y_pred, average='micro')))\n","print(\"ROC: \",(roc_auc_score(y_true, y_pred, average=\"micro\")))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Classification report: \n","               precision    recall  f1-score   support\n","\n","           0       0.97      0.98      0.97       846\n","           1       0.99      0.98      0.98      1642\n","           2       0.93      0.70      0.80       300\n","           3       0.90      0.56      0.69       209\n","           4       0.91      0.72      0.81       246\n","           5       0.91      0.79      0.85       333\n","           6       0.95      0.84      0.90       288\n","           7       0.91      0.82      0.86       260\n","           8       0.99      0.99      0.99       267\n","           9       1.00      0.99      0.99      1275\n","     
     10       0.99      0.99      0.99      1458\n","          11       0.96      0.90      0.93       976\n","          12       0.95      0.97      0.96      1844\n","          13       1.00      0.99      0.99       492\n","          14       0.99      0.98      0.99       613\n","          15       0.97      0.98      0.98       632\n","          16       0.99      0.97      0.98       365\n","          17       1.00      0.97      0.99       145\n","          18       1.00      0.93      0.96        83\n","          19       1.00      0.98      0.99       136\n","          20       1.00      0.99      0.99       228\n","          21       1.00      0.69      0.82        36\n","          22       1.00      0.95      0.97        38\n","          23       1.00      0.50      0.67         4\n","          24       1.00      1.00      1.00       222\n","          25       0.99      1.00      0.99       240\n","          26       1.00      0.67      0.80         6\n","          27       1.00      0.94      0.97        32\n","          28       0.99      1.00      0.99       703\n","          29       1.00      1.00      1.00       524\n","          30       1.00      1.00      1.00       612\n","          31       1.00      0.94      0.97        88\n","          32       1.00      0.97      0.98       267\n","          33       1.00      1.00      1.00       297\n","          34       1.00      0.98      0.99        82\n","          35       1.00      0.89      0.94        18\n","          36       1.00      0.97      0.98       251\n","          37       1.00      1.00      1.00       348\n","          38       1.00      1.00      1.00       393\n","          39       1.00      0.99      1.00       390\n","          40       1.00      0.98      0.99       333\n","          41       1.00      1.00      1.00       794\n","          42       1.00      0.98      0.99        52\n","          43       1.00      0.50      0.67         8\n","          44       1.00      
0.88      0.93         8\n","          45       0.00      0.00      0.00         4\n","          46       0.90      0.78      0.83       303\n","          47       0.89      0.70      0.78       425\n","          48       0.89      0.78      0.83       349\n","          49       0.93      0.80      0.86       373\n","          50       0.82      0.42      0.56       170\n","          51       0.95      0.67      0.79       220\n","\n","   micro avg       0.98      0.94      0.95     20228\n","   macro avg       0.96      0.86      0.90     20228\n","weighted avg       0.97      0.94      0.95     20228\n"," samples avg       0.98      0.94      0.96     20228\n","\n","F1 micro averaging: 0.9549113112810033\n","ROC:  0.9659676982287029\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"2BB-NwZUoHSe"},"source":["# 5. Let's save the model"]},{"cell_type":"code","metadata":{"id":"eLex095goHwm","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609535641300,"user_tz":-60,"elapsed":243837,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"458863e7-50f4-4cfe-dfdd-1b3edde4e8d8"},"source":["stored_model_path = './models/multi_classifier_dl_trained' \n","fitted_pipe.save(stored_model_path)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Stored model in ./models/multi_classifier_dl_trained\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"e_b2DPd4rCiU"},"source":["# 6. Let's load the model from HDD.\n","This makes offline NLU usage possible!   
\n","You need to call nlu.load(path=path_to_the_pipe) to load a model/pipeline from disk."]},{"cell_type":"code","metadata":{"id":"SO4uz45MoRgp","colab":{"base_uri":"https://localhost:8080/","height":103},"executionInfo":{"status":"ok","timestamp":1609535674624,"user_tz":-60,"elapsed":274401,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"589912b1-32b5-4333-fe84-46cf40658451"},"source":["hdd_pipe = nlu.load(path=stored_model_path)\n","\n","preds = hdd_pipe.predict('Tesla plans to invest 10M into the ML sector')\n","preds"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>multi_classifier_classes</th>\n","      <th>document</th>\n","      <th>multi_classifier_confidences</th>\n","      <th>en_embed_sentence_small_bert_L12_768_embeddings</th>\n","    </tr>\n","    <tr>\n","      <th>origin_index</th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>[customer rating[high], customer rating[low], ...</td>\n","      <td>Tesla plans to invest 10M into the ML sector</td>\n","      <td>[0.9597453, 0.6497742, 0.986845, 0.5315694, 0....</td>\n","      <td>[0.15737222135066986, 0.2598555386066437, 0.85...</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>"],"text/plain":["                                       multi_classifier_classes  ...    
en_embed_sentence_small_bert_L12_768_embeddings\n","origin_index                                                     ...                                                   \n","0             [customer rating[high], customer rating[low], ...  ...  [0.15737222135066986, 0.2598555386066437, 0.85...\n","\n","[1 rows x 4 columns]"]},"metadata":{"tags":[]},"execution_count":7}]},{"cell_type":"code","metadata":{"id":"e0CVlkk9v6Qi","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1609535674627,"user_tz":-60,"elapsed":273679,"user":{"displayName":"Christian Kasim Loan","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjqAD-ircKP-s5Eh6JSdkDggDczfqQbJGU_IRb4Hw=s64","userId":"14469489166467359317"}},"outputId":"926c0a81-339a-49b8-e9ea-7f3ce049ca01"},"source":["hdd_pipe.print_info()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["The following parameters are configurable for this NLU pipeline (You can copy paste the examples) :\n",">>> pipe['document_assembler'] has settable params:\n","pipe['document_assembler'].setCleanupMode('shrink')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              
                                                                                                                                                                                                                           | Info: possible values: disabled, inplace, inplace_full, shrink, shrink_full, each, each_full, delete_full | Currently set to : shrink\n",">>> pipe['regex_tokenizer'] has settable params:\n","pipe['regex_tokenizer'].setCaseSensitiveExceptions(True)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    | Info: Whether to care for case sensitiveness in exceptions | Currently set to : True\n","pipe['regex_tokenizer'].setTargetPattern('\\S+')                                                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                | Info: pattern to grab from text as token candidates. Defaults \\S+ | Currently set to : \\S+\n","pipe['regex_tokenizer'].setMaxLength(99999)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
                                                                                               | Info: Set the maximum allowed length for each token | Currently set to : 99999\n","pipe['regex_tokenizer'].setMinLength(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     | Info: Set the minimum allowed length for each token | Currently set to : 0\n",">>> pipe['sentence_detector'] has settable params:\n","pipe['sentence_detector'].setCustomBounds([])                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    | Info: characters used to explicitly mark sentence bounds | Currently set to : []\n","pipe['sentence_detector'].setDetectLists(True)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              | Info: whether detect lists during sentence detection | Currently set to : 
True\n","pipe['sentence_detector'].setExplodeSentences(False)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        | Info: whether to explode each sentence into a different row, for better parallelization. Defaults to false. 
| Currently set to : False\n","pipe['sentence_detector'].setMaxLength(99999)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               | Info: Set the maximum allowed length for each sentence | Currently set to : 99999\n","pipe['sentence_detector'].setMinLength(0)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    
                                                                                                                                                                                                                                                                                                                                                                                                                                                               | Info: Set the minimum allowed length for each sentence. | Currently set to : 0\n","pipe['sentence_detector'].setUseAbbreviations(True)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         | Info: whether to apply abbreviations at sentence detection | Currently set to : True\n","pipe['sentence_detector'].setUseCustomBoundsOnly(False)                                                                                                                                                              
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       | Info: Only utilize custom bounds in sentence detection | Currently set to : False\n",">>> pipe['glove'] has settable params:\n","pipe['glove'].setBatchSize(32)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
                                                                                                                                                                                                                                                      | Info: Batch size. Large values allows faster processing but requires more memory. | Currently set to : 32\n","pipe['glove'].setCaseSensitive(False)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       | Info: whether to ignore case in tokens for embeddings matching | Currently set to : False\n","pipe['glove'].setDimension(768)                                                                                                                                                                                                                                                                                                                                                               
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              | Info: Number of embedding dimensions | Currently set to : 768\n","pipe['glove'].setMaxSentenceLength(128)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       
              | Info: Max sentence length to process | Currently set to : 128\n","pipe['glove'].setIsLong(False)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              | Info: Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int. 
| Currently set to : False\n","pipe['glove'].setStorageRef('sent_small_bert_L12_768')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      | Info: unique reference name for identification | Currently set to : sent_small_bert_L12_768\n",">>> pipe['multi_classifier'] has settable params:\n","pipe['multi_classifier'].setThreshold(0.5)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               | Info: The minimum threshold for each label to be accepted. Default is 0.5 | Currently set to : 0.5\n","pipe['multi_classifier'].setClasses(['name[Clowns]', 'name[Cotto]', 'near[Burger King]', 'near[Crowne Plaza Hotel]', 'customer rating[high]', 'near[Avalon]', 'near[The Bakers]', 'near[Ranch]', 'eatType[restaurant]', 'near[All Bar One]', 'customer rating[low]', 'near[Café Sicilia]', 'food[Indian]', 'eatType[pub]', 'name[Green Man]', 'name[Strada]', 'eatType[coffee shop]', 'name[Loch Fyne]', 'customer rating[5 out of 5]', 'near[Express by Holiday Inn]', 'food[French]', 'name[The Mill]', 'food[Japanese]', 'name[The Plough]', 'name[Cocum]', 'name[The Phoenix]', 'priceRange[cheap]', 'near[Rainbow Vegetarian Café]', 'near[The Rice Boat]', 'customer rating[3 out of 5]', 'customer rating[1 out of 5]', 'name[The Cricketers]', 'area[riverside]', 'name[Blue Spice]', 'priceRange[£20-25]', 'priceRange[less than £20]', 'priceRange[moderate]', 'priceRange[high]', 'name[Giraffe]', 'customer rating[average]', 'food[Fast food]', 'near[Café Rouge]', 'area[city centre]', 'familyFriendly[no]', 'food[Chinese]', 'food[Italian]', 'near[Raja Indian Cuisine]', 'priceRange[more than £30]', 'name[The Punter]', 'food[English]', 'near[The Sorrento]', 'familyFriendly[yes]'])  | Info: get the tags used to trained this NerDLModel | Currently set to : ['name[Clowns]', 'name[Cotto]', 'near[Burger King]', 'near[Crowne Plaza Hotel]', 'customer rating[high]', 'near[Avalon]', 'near[The Bakers]', 
'near[Ranch]', 'eatType[restaurant]', 'near[All Bar One]', 'customer rating[low]', 'near[Café Sicilia]', 'food[Indian]', 'eatType[pub]', 'name[Green Man]', 'name[Strada]', 'eatType[coffee shop]', 'name[Loch Fyne]', 'customer rating[5 out of 5]', 'near[Express by Holiday Inn]', 'food[French]', 'name[The Mill]', 'food[Japanese]', 'name[The Plough]', 'name[Cocum]', 'name[The Phoenix]', 'priceRange[cheap]', 'near[Rainbow Vegetarian Café]', 'near[The Rice Boat]', 'customer rating[3 out of 5]', 'customer rating[1 out of 5]', 'name[The Cricketers]', 'area[riverside]', 'name[Blue Spice]', 'priceRange[£20-25]', 'priceRange[less than £20]', 'priceRange[moderate]', 'priceRange[high]', 'name[Giraffe]', 'customer rating[average]', 'food[Fast food]', 'near[Café Rouge]', 'area[city centre]', 'familyFriendly[no]', 'food[Chinese]', 'food[Italian]', 'near[Raja Indian Cuisine]', 'priceRange[more than £30]', 'name[The Punter]', 'food[English]', 'near[The Sorrento]', 'familyFriendly[yes]']\n","pipe['multi_classifier'].setStorageRef('sent_small_bert_L12_768')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
                                                                                                                                                         | Info: unique reference name for identification | Currently set to : sent_small_bert_L12_768\n"],"name":"stdout"}]}]}